FP8 kv cache quantization#4563
Draft
CUHKSZzxy wants to merge 16 commits into InternLM:main from
Conversation
Adds FP8 KV cache quantization (QuantPolicy.FP8 = 16) using torch.float8_e4m3fn with per-token symmetric scale (no zero point).

Key design:
- Reuses existing fill_kv_cache_blocked_fp8() with group_size=head_dim for per-token scale semantics in the fill path
- Dequant in flatten_kv_cache and paged_attention via x.to(f32) * scale
- Scale tensor shape [..., 1]: symmetric, no zero point
- No bit packing (head_dim unchanged, unlike INT4/TURBO_QUANT)

Also fixes pre-existing TestFillKVCacheBlockedFP8 test failures caused by calling .max() on float8_e4m3fn tensors (cast to float32 first).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
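For reference, a minimal PyTorch sketch of the per-token symmetric scheme described above (this is not the fill_kv_cache_blocked_fp8 kernel; the tensor layout and helper names are made up for illustration, and the amax is taken before the FP8 cast, consistent with the test fix):

```python
import torch

# Minimal sketch of per-token symmetric FP8 (E4M3) quantization of a K/V slice.
# `kv` is assumed to be [num_tokens, num_heads, head_dim]; the real kernel works
# on blocked cache layouts instead.
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quant_per_token(kv: torch.Tensor):
    # Symmetric scale per (token, head): amax over head_dim only, no zero point.
    amax = kv.abs().amax(dim=-1, keepdim=True).to(torch.float32)     # shape [..., 1]
    scale = (amax / FP8_E4M3_MAX).clamp(min=1e-12)
    q = (kv.to(torch.float32) / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.to(torch.float8_e4m3fn), scale

def dequant(q_fp8: torch.Tensor, scale: torch.Tensor):
    # Mirrors the x.to(f32) * scale dequant used in flatten_kv_cache / paged_attention.
    return q_fp8.to(torch.float32) * scale
```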
Avoid constructing a temporary cu_seqlen_q tensor in the FP8 cache-fill path by letting fill_kv_cache_blocked_fp8 consume the existing q_start_loc and q_seq_length metadata directly. The kernel keeps the old cumulative-seqlen mode for direct callers via a USE_CU_SEQLEN constexpr.

Move default paged-decode FP8 dequant scaling across the attention dot products. K scales are applied after QK, and V scales are applied to the probabilities before PV, which preserves the per-token/head scale algebra while avoiding full K/V tile dequantization in the hot decode loop.

Add a focused FP8 paged-attention test that compares against a dequantized-FP8 reference, including a split-head-dim case, so the fused scale placement is covered without conflating it with expected quantization error.
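The fused scale placement can be sanity-checked with plain PyTorch; the following self-contained sketch (per-token scales, no Triton, names invented) verifies that applying K scales after QK and V scales to the probabilities before PV matches attention over fully dequantized K/V:

```python
import torch

torch.manual_seed(0)
D, T = 64, 32                        # head_dim, number of cached tokens
q = torch.randn(1, D)
k_fp8 = torch.randn(T, D)            # stands in for the FP8 payload (kept fp32 here)
v_fp8 = torch.randn(T, D)
k_scale = torch.rand(T, 1) + 0.5     # per-token scales, shape [T, 1]
v_scale = torch.rand(T, 1) + 0.5

# Reference: dequantize the K/V tiles first, then run attention.
ref_probs = torch.softmax(q @ (k_fp8 * k_scale).T, dim=-1)
ref_out = ref_probs @ (v_fp8 * v_scale)

# Fused placement: k_scale applied after QK, v_scale applied to the probabilities before PV.
scores = (q @ k_fp8.T) * k_scale.T
probs = torch.softmax(scores, dim=-1)
out = (probs * v_scale.T) @ v_fp8

torch.testing.assert_close(ref_out, out, atol=1e-4, rtol=1e-4)
```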
Split normal FP8 KV cache from the dynamic per-token/head FP8 path. Normal fp8/fp8_e4m3 and fp8_e5m2 now use scalar K/V scales with FP8 cache tensors and no k_scales_zeros/v_scales_zeros metadata allocation, while fp8_per_token_head variants keep the existing per-token/head scale-cache behavior.

Thread scalar k_scale/v_scale through PyTorch attention dispatch, cache fill, flatten, and paged decode kernels so normal FP8 can quantize on cache write and apply scalar dequant in decode/prefill without materialized metadata tensors. Add optional one-shot calculate_kv_scales support and guard CUDA graph capture while scale calculation is pending, mirroring vLLM's eager first-pass behavior.

Add focused CLI/config/cache descriptor tests and scalar/per-token FP8 kernel reference coverage.

Validation: py_compile on changed runtime/kernel/test files; pytest -q tests/test_lmdeploy/test_fp8_kv_cache_policy.py; git diff --check. CUDA kernel tests were not run because nvidia-smi cannot communicate with the driver in this environment.
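In rough pseudocode, the scalar-scale write/read path corresponds to something like the sketch below (helper names are invented; the actual logic lives in the cache-fill and decode kernels):

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for E4M3

def fill_fp8_cache_scalar(k_new: torch.Tensor, k_scale: float) -> torch.Tensor:
    # Quantize on cache write with a single per-layer scalar scale (default 1.0).
    q = (k_new.to(torch.float32) / k_scale).clamp(-FP8_MAX, FP8_MAX)
    return q.to(torch.float8_e4m3fn)

def read_fp8_cache_scalar(k_cache: torch.Tensor, k_scale: float,
                          dtype=torch.float16) -> torch.Tensor:
    # Scalar dequant on read; in the real kernels this happens inside decode/prefill,
    # so no per-token scale metadata tensors are materialized.
    return (k_cache.to(torch.float32) * k_scale).to(dtype)

k_new = torch.randn(8, 4, 64, dtype=torch.float16)
k_cache = fill_fp8_cache_scalar(k_new, k_scale=1.0)
k_deq = read_fp8_cache_scalar(k_cache, k_scale=1.0)
```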
Remove the deprecated-style dynamic KV scale calculation path and keep normal FP8 on the vLLM-aligned scalar-scale behavior with default scales. Drop the experimental per-token/head FP8 policy and tests so the public surface only exposes fp8, fp8_e4m3, and fp8_e5m2. Sadly we have to remove some potentially useful features to keep this PR concise and solid.
Summary
This PR adds PyTorch CUDA/Triton FP8 KV-cache support with a concise public policy surface aligned with vLLM-style normal FP8 behavior.
Supported policies:
- fp8 / fp8_e4m3 -> scalar-scale torch.float8_e4m3fn KV cache
- fp8_e5m2 -> scalar-scale torch.float8_e5m2 KV cache

The implementation intentionally does not expose per-attention-head / per-token-head FP8 KV-cache modes or dynamic calculate_kv_scales-style calibration. Those paths were removed from the public surface to keep this PR focused on the normal scalar-scale FP8 path.

What Changed
Public API / Config
- QuantPolicy.FP8 for E4M3 scalar-scale FP8 KV cache.
- QuantPolicy.FP8_E5M2 for E5M2 scalar-scale FP8 KV cache.
- --quant-policy fp8 defaults to E4M3
- --quant-policy fp8_e4m3 maps to E4M3
- --quant-policy fp8_e5m2 maps to E5M2
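Assuming the mapping above, end-to-end usage might look like the following (illustrative sketch only; quant_policy=16 is taken from the first commit message's QuantPolicy.FP8 = 16, and the model name is just an example):

```python
from lmdeploy import PytorchEngineConfig, pipeline

# quant_policy=16 is assumed to correspond to QuantPolicy.FP8 (E4M3 scalar-scale
# KV cache) per the commit message above; use the E5M2 policy value for fp8_e5m2.
engine_config = PytorchEngineConfig(quant_policy=16)
pipe = pipeline('internlm/internlm2_5-7b-chat', backend_config=engine_config)
print(pipe('Hello, FP8 KV cache!'))
```

On the CLI, the equivalent would be passing --quant-policy fp8 / fp8_e4m3 / fp8_e5m2 as listed above.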
Cache Allocation

- KV cache blocks are allocated as torch.float8_e4m3fn or torch.float8_e5m2.
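Purely as a dtype illustration (the block shape and sizes below are made-up placeholders, not the cache engine's actual layout):

```python
import torch

# Hypothetical block layout [num_blocks, block_size, num_kv_heads, head_dim];
# the relevant part is only the element dtype. head_dim is unchanged (no bit packing).
cache_dtype = torch.float8_e4m3fn  # torch.float8_e5m2 for the fp8_e5m2 policy
k_cache = torch.empty(256, 64, 8, 128, dtype=cache_dtype)
v_cache = torch.empty_like(k_cache)
```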
Attention Runtime

- k_scale / v_scale buffers on PyTorch attention layers.
- Scales default to 1.0, i.e. k_scale = v_scale = 1.0 unless set otherwise.
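In module terms, the scalar scales amount to something like this sketch (illustrative names; the real wiring is in the PyTorch attention dispatch changes):

```python
import torch
from torch import nn

class AttentionWithKVScales(nn.Module):
    # Illustrative module only: holds scalar K/V dequant scales next to the layer,
    # defaulting to 1.0 so FP8 caching behaves as an identity rescale until set.
    def __init__(self):
        super().__init__()
        self.register_buffer('k_scale', torch.tensor(1.0, dtype=torch.float32))
        self.register_buffer('v_scale', torch.tensor(1.0, dtype=torch.float32))
```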
Kernels

- Quantize on cache write: value / scale.
- Dequantize on read: stored * scale.
- k_scale is applied in the QK score computation.
- v_scale is applied before the PV accumulation.

Tests
Scope
This PR focuses on the LMDeploy PyTorch CUDA/Triton attention path.
Out of scope:
Rationale
The previous experimental direction included per-token/head scale metadata and dynamic scale calculation. That made the implementation heavier and less aligned with the common normal FP8 KV-cache path used by vLLM.
This PR keeps the first upstreamable FP8 KV-cache feature smaller and easier to validate:
Test Plan